{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Location of the PIXIU checkout. Set the PIXIU_ROOT environment variable to\n",
    "# override the default when running on another machine.\n",
    "PIXIU_ROOT = os.environ.get('PIXIU_ROOT', '/data/hanweiguang/Projects/PIXIU')\n",
    "\n",
    "# Guard the append so re-running this cell does not add duplicate entries.\n",
    "if PIXIU_ROOT not in sys.path:\n",
    "    sys.path.append(PIXIU_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import MultiClient\n",
    "from sklearn.metrics import confusion_matrix, matthews_corrcoef, f1_score, accuracy_score\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional generation settings; they are appended after the prompt in every\n",
    "# request, so the order of the entries must not change.\n",
    "GENERATION_CONFIG = [\n",
    "    # 'Temperature' Slider component: numeric value between 0 and 1\n",
    "    0.1,\n",
    "    # 'Top p' Slider component: numeric value between 0 and 1\n",
    "    0.75,\n",
    "    # 'Top k' Slider component: numeric value between 0 and 100\n",
    "    40,\n",
    "    # 'Beams Number' Slider component: numeric value between 1 and 4\n",
    "    1,\n",
    "    # do sample\n",
    "    True,\n",
    "    # 'Max New Tokens' Slider component: numeric value between 1 and 2000\n",
    "    8,\n",
    "    # 'Min New Tokens' Slider component: numeric value between 1 and 300\n",
    "    1,\n",
    "    # 'Repetition Penalty' Slider component: numeric value between 1.0 and 2.0\n",
    "    1.2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cikm18 test file: one JSON object per line.\n",
    "# The encoding is explicit so the result does not depend on the platform default.\n",
    "with open(\"/data/hanweiguang/Projects/PIXIU/data/cikm18/test.jsonl\", encoding=\"utf-8\") as f:\n",
    "    # Stream the file line by line and skip blank lines, so that a stray empty\n",
    "    # line (e.g. at the end of the file) does not raise json.JSONDecodeError.\n",
    "    data = [json.loads(line) for line in f if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded as API: http://127.0.0.1:17860/ ✔\n"
     ]
    }
   ],
   "source": [
    "# One local worker endpoint per consecutive port, starting at 17860\n",
    "# (currently a single worker).\n",
    "worker_addrs = [f\"http://127.0.0.1:{port}\" for port in range(17860, 17860 + 1)]\n",
    "\n",
    "# Client wrapper that talks to all listed workers.\n",
    "clients = MultiClient(worker_addrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [03:21<00:00,  1.01s/it]\n"
     ]
    }
   ],
   "source": [
    "# Prompt = value of the first conversation turn, for the first 200 examples.\n",
    "prompts = [example[\"conversations\"][0][\"value\"] for example in data[:200]]\n",
    "\n",
    "# One request per prompt: the prompt followed by the shared generation settings.\n",
    "payloads = [[prompt, *GENERATION_CONFIG] for prompt in prompts]\n",
    "results = clients.predict(payloads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gold labels for the same 200-example slice that was sent for prediction.\n",
    "labels = [example[\"label\"] for example in data[:200]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _is_rise(answer):\n",
    "    \"\"\"Return 1 if `answer` is \"Rise\" (surrounding whitespace ignored), else 0.\"\"\"\n",
    "    # Generated text may carry leading/trailing spaces or newlines; an exact\n",
    "    # comparison would silently count such a \"Rise\" answer as 0.\n",
    "    if isinstance(answer, str):\n",
    "        answer = answer.strip()\n",
    "    return 1 if answer == \"Rise\" else 0\n",
    "\n",
    "\n",
    "# Labels and predictions are compared position by position.\n",
    "assert len(labels) == len(results), \"labels and results differ in length\"\n",
    "\n",
    "# Binary encoding: 1 for \"Rise\", 0 for anything else.\n",
    "y_true = [_is_rise(label) for label in labels]\n",
    "y_pred = [_is_rise(answer) for answer in results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MCC: -0.05380001385625025\n"
     ]
    }
   ],
   "source": [
    "# Confusion-matrix counts. labels=[0, 1] pins the matrix to 2x2, so the\n",
    "# four-way unpacking still works when only one class occurs in the data.\n",
    "tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
    "\n",
    "# Calculate Matthews correlation coefficient\n",
    "mcc = matthews_corrcoef(y_true, y_pred)\n",
    "print(f'MCC: {mcc}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1: 0.4573069852941177\n"
     ]
    }
   ],
   "source": [
    "# F1 score using the 'weighted' averaging mode.\n",
    "f1 = f1_score(y_true=y_true, y_pred=y_pred, average=\"weighted\")\n",
    "print(f\"F1: {f1}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy: 0.51\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the binary predictions against the binary labels.\n",
    "accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)\n",
    "print(f\"accuracy: {accuracy}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
